import argparse
import multiprocessing

import torch

from centralized_verification.paths import RUNNING_ON_DISCOVERY
from experiments.utils.parallel_experiment import ParallelExperiment, CSVParallelExperiment

# Debug mode is on whenever RUNNING_ON_DISCOVERY is false
# (presumably: whenever we are not running on the Discovery cluster).
DEBUG = not RUNNING_ON_DISCOVERY
# Torch device constant; NOTE(review): not referenced in the code visible here.
DEVICE = torch.device("cpu")
# When DEBUG is false: True -> run_parallel_experiment spawns one process per
# index; False -> indices run serially in the calling process.
PARALLEL_NON_DEBUG = True
# Same switch, but consulted when DEBUG is true. False keeps debug runs serial
# and in-process.
PARALLEL_DEBUG = False


def run_on_thread(experiment: ParallelExperiment, global_index: int):
    """Run the single experiment configuration identified by ``global_index``.

    This is the ``multiprocessing.Process`` target in parallel mode, and it is
    called directly in serial mode. Despite its name, in parallel mode it runs
    in a separate process, not a thread. It forwards the index to
    ``experiment.run_at_index`` and returns nothing.
    """
    run_one = experiment.run_at_index
    run_one(global_index)


def parse_args_and_run_parallel_csv_experiment(experiment_runner, tags, argv=None):
    """Parse node/thread CLI arguments and run this node's share of a CSV experiment.

    Expected command line: ``--node-idx N --threads-per-node T FILENAME``.

    Args:
        experiment_runner: Passed unchanged to ``CSVParallelExperiment``.
        tags: Iterable of tags. It is copied, so the caller's object is never
            modified. ``"DEBUG"`` (when ``DEBUG`` is set) and then the CSV
            filename are appended to the copy, which is handed to
            ``CSVParallelExperiment``.
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Exits via ``parser.error`` (status 2) if ``--node-idx`` is negative or
    ``--threads-per-node`` is less than 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--node-idx", dest="node_idx", type=int, required=True)
    parser.add_argument("--threads-per-node", dest="threads_per_node", type=int, required=True)
    parser.add_argument("filename", type=str)

    args = parser.parse_args(argv)

    # A negative node index yields negative global indices, which are presumably
    # not valid experiment indices (and could silently wrap if used for Python
    # sequence indexing). A non-positive thread count would silently do no work.
    if args.node_idx < 0:
        parser.error("--node-idx must be >= 0")
    if args.threads_per_node < 1:
        parser.error("--threads-per-node must be >= 1")

    # Copy so a shared or reused list does not accumulate tags across calls.
    tags = list(tags)
    if DEBUG:
        tags.append("DEBUG")
    tags.append(args.filename)

    experiment = CSVParallelExperiment(args.filename, experiment_runner, tags)
    run_parallel_experiment(experiment, args.node_idx, args.threads_per_node)


def run_parallel_experiment(experiment: ParallelExperiment, node_idx: int, threads_per_node: int):
    """Run the contiguous block of experiment indices assigned to one node.

    Node ``node_idx`` owns the global indices
    ``[node_idx * threads_per_node, (node_idx + 1) * threads_per_node)``.
    Each index is handled by ``run_on_thread``. The index runs either in its
    own ``multiprocessing.Process`` (when ``PARALLEL_DEBUG`` / ``PARALLEL_NON_DEBUG``
    enables it for the current ``DEBUG`` mode) or serially in this process.

    In parallel mode this waits for every child process to finish before
    returning. A failing child does not stop its siblings.

    Args:
        experiment: Object whose ``run_at_index`` is invoked once per index.
        node_idx: Zero-based index of this node.
        threads_per_node: Number of indices handled by each node.

    Raises:
        RuntimeError: In parallel mode, if any child process exits with a
            non-zero exit code. This is raised only after all children have
            been joined.
    """
    start_idx = node_idx * threads_per_node
    indices = range(start_idx, start_idx + threads_per_node)

    # The execution mode cannot change inside the loop, so decide it once.
    use_processes = PARALLEL_DEBUG if DEBUG else PARALLEL_NON_DEBUG

    if not use_processes:
        for i in indices:
            run_on_thread(experiment, i)
        return

    processes = []
    for i in indices:
        p = multiprocessing.Process(target=run_on_thread, args=(experiment, i))
        p.start()
        processes.append((i, p))

    # Join every child so the caller observes completion. Collect failures
    # instead of letting a crashed worker go unnoticed.
    failed = []
    for i, p in processes:
        p.join()
        if p.exitcode != 0:
            failed.append((i, p.exitcode))

    if failed:
        details = ", ".join(f"index {i} (exit code {code})" for i, code in failed)
        raise RuntimeError(f"Experiment worker processes failed: {details}")
